import os
import time
import h5py
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import L1Loss
from sklearn.preprocessing import MinMaxScaler
from scipy.stats import norm, lognorm
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
from torch.utils.data import Dataset, DataLoader
import math
import argparse 
from pathlib import Path
import json
from AdaPlus import AdaPlus
try:
    from torch.utils.tensorboard import SummaryWriter
except Exception:
    SummaryWriter = None
from gru_model_standalone import GRUModel, BasicGRUModel, GRUModelLayerNormChrono, apply_chrono_init, apply_chrono_init_layernorm_chrono_model

# Command-line interface.
# Every option below is read later through `args.<name>`; names, types and
# defaults are part of the script's public interface.
parser = argparse.ArgumentParser()

# --- optimisation schedule ---
parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate for optimizer')
parser.add_argument('--T_0', type=int, default=50, help='First restart period for cosine annealing')
parser.add_argument('--T_mult', type=int, default=2, help='Multiplier for restart period in cosine annealing')
parser.add_argument('--epochs', type=int, default=5000, help='Number of training epochs')

# --- data windowing / model size ---
parser.add_argument('--timesteps', type=int, default=160, help='Number of time steps per sequence')
parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--hidden_size', type=int, default=144, help='GRU hidden size')
parser.add_argument('--num_layers', type=int, default=2, help='Number of GRU layers')
parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')

# --- resuming, logging, checkpoint locations ---
parser.add_argument('--start_epoch', type=int, default=0, help='only used when resuming from model-only weights; otherwise ignored. ')
parser.add_argument('--resume', type=str, default='', help = 'path to checkpoint(.pth). if it only contains model weights,we still load them.')
parser.add_argument('--logdir', type=str, default='runs/gru', help='Tensorboard log directory')
parser.add_argument('--ckptdir', type=str, default='checkpoints/gru',
                    help='Where to save checkpoints')
parser.add_argument('--save_every', type=int, default=500,
                    help='Save checkpoint every N epochs')
parser.add_argument('--data', type=str, required=True,
                    help='Path to HDF5 dataset (e.g., train_dataset_2_t_p_20000.h5)')
parser.add_argument(
    # Same values as the literal list 1000, 2000, ..., 15000.
    '--slice_epochs', type=int, nargs='*', default=list(range(1000, 15001, 1000)),
    help='Epoch milestones at which to snapshot val metrics (1-based)'
)
parser.add_argument(
    '--snapdir', type=str, default=None,
    help='Optional subfolder for metric snapshots (defaults to <output_dir>/snapshots)'
)

# --- optimizer selection and its hyperparameters ---
parser.add_argument('--optimizer', type=str, default='adamw', choices=['adamw', 'adaplus', 'adam'], help = 'which optimizer to choose' )
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for AdamW/AdaPlus')
# Help text previously said "beta1" for this option.
parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for AdamW/AdaPlus')
parser.add_argument('--eps', type=float, default=1e-8, help='epsilon for AdamW/AdaPlus')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='weight decay (decoupled) for AdamW/AdaPlus')
parser.add_argument('--amsgrad', action='store_true', help='use AMSGrad variant if supported,  for AdamW/AdaPlus')

# --- model variant, scheduler, chrono initialisation ---
parser.add_argument('--model_type', type=str, default='layernorm', choices=['layernorm', 'basic'], help='layernorm: custom LayerNorm GRU; basic: nn.GRU')
parser.add_argument('--scheduler', type=str, default='cosine', choices=['cosine', 'none'], help='cosine: CosineAnnealingWarmRestarts; none: no scheduler')
parser.add_argument('--chrono_init', action='store_true', help='Enable chrono initialization (basic GRU or LayerNorm GRU chrono variant)')
parser.add_argument('--chrono_tmax', type=int, default=None, help='Tmax for chrono init (default: use --timesteps)')
parser.add_argument('--chrono_tmin', type=float, default=1.0, help='Tmin for chrono init')
parser.add_argument('--chrono_debug', action='store_true', help='print afew chrono_initialized gate biases')

# Parsed once at import time; the rest of the script reads from `args`.
args = parser.parse_args()
def _get_lr(optimizer):
    """Return the "lr" entry of the optimizer's first parameter group.

    Returns None when the first group has no "lr" key or when the
    optimizer has no parameter groups at all.
    """
    first_group_lr = (group.get("lr", None) for group in optimizer.param_groups)
    return next(first_group_lr, None)

def save_checkpoint(path, epoch, model, optimizer, scheduler,
                    best_val_mse, best_val_nrmse, hparams=None, extra=None):
    """Serialize a training checkpoint to `path` with torch.save.

    The saved dict holds the epoch, the model and optimizer state dicts,
    the two best-validation values and `hparams` (an empty dict when
    `hparams` is falsy). A 'scheduler_state_dict' entry is added only when
    `scheduler` is not None. Entries of `extra`, if given, are merged in
    last and therefore override same-named keys.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        best_val_mse=best_val_mse,
        best_val_nrmse=best_val_nrmse,
        hparams=hparams if hparams else {},
    )

    # Runs without a scheduler simply omit this key.
    if scheduler is not None:
        payload['scheduler_state_dict'] = scheduler.state_dict()

    if extra:
        payload.update(extra)

    torch.save(payload, path)

def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location=None):
    """Restore training state from a file written with torch.save.

    The file may be a full checkpoint dict (weights under
    'model_state_dict') or a bare model state dict; in the latter case the
    loaded object itself is used as the weights. Weights are loaded with
    strict=False. Optimizer / scheduler state is restored only when the
    corresponding object is passed and its key exists in the file.

    Returns:
        (epoch, best_val_mse, best_val_nrmse), falling back to
        0, inf and inf respectively for keys missing from the file.
    """
    checkpoint = torch.load(path, map_location=map_location)

    # Full checkpoint -> nested weights; otherwise treat the file as weights.
    if 'model_state_dict' in checkpoint:
        weights = checkpoint['model_state_dict']
    else:
        weights = checkpoint
    model.load_state_dict(weights, strict=False)

    restorable = (
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    )
    for component, key in restorable:
        if component is not None and key in checkpoint:
            component.load_state_dict(checkpoint[key])

    unset = float('inf')
    return (
        checkpoint.get('epoch', 0),
        checkpoint.get('best_val_mse', unset),
        checkpoint.get('best_val_nrmse', unset),
    )



# Utilities for loading HDF5

def load_hdf5_data(file_path):
    """
    Read every top-level entry of an HDF5 file into memory.

    For each key in the file, one dict is produced with:
      'filename'           - the entry's 'filename' attribute
      'Time'               - full contents of its 'Time' dataset
      'scaled_parameters'  - full contents of its 'scaled_parameters' dataset
      'normalized_current' - full contents of its 'normalized_current' dataset

    Returns the list of these dicts, in the file's key iteration order.
    """
    def _read_entry(entry):
        # `[:]` copies the dataset out so it stays valid after the file closes.
        return {
            'filename': entry.attrs['filename'],
            'Time': entry['Time'][:],
            'scaled_parameters': entry['scaled_parameters'][:],
            'normalized_current': entry['normalized_current'][:]
        }

    with h5py.File(file_path, 'r') as h5_file:
        return [_read_entry(h5_file[name]) for name in h5_file.keys()]



#  Sequence creation

def create_gru_sequences(data, timesteps):
    """Cut each sample into sliding (input, target) windows.

    For every sample and every start index i in
    range(len(target) - timesteps), the input window is
    features[i : i + timesteps] and the target window is the target
    series shifted one step ahead, target[i + 1 : i + 1 + timesteps].

    Args:
        data: iterable of dicts with 'scaled_parameters' and
            'normalized_current' entries.
        timesteps: window length.

    Returns:
        (X, y) as numpy arrays built from the collected windows.
    """
    inputs, targets = [], []
    for record in data:
        feats = record['scaled_parameters']
        current = record['normalized_current']
        n_windows = len(current) - timesteps
        for start in range(n_windows):
            stop = start + timesteps
            inputs.append(feats[start:stop])
            # Target is the same-length window, one step later.
            targets.append(current[start + 1:stop + 1])
    return np.array(inputs), np.array(targets)


# Hyperparameters, device and run naming

# Prefer the GPU when one is visible to torch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Short module-level aliases for the parsed CLI values.
file_path = args.data
timesteps = args.timesteps
batch_size = args.batch_size
hidden_size = args.hidden_size
num_layers = args.num_layers
dropout = args.dropout
lr = args.lr
epochs = args.epochs

# Dataset file name without directory or extension.
dataset_tag = os.path.splitext(os.path.basename(args.data))[0]
model_tag = {'layernorm': 'layerNorm'}.get(args.model_type, 'basic_gru')

if args.chrono_init:
    # Tmax falls back to the sequence length when not given explicitly.
    _tmax_used = args.timesteps if args.chrono_tmax is None else args.chrono_tmax
    chrono_tag = f'chrono_on_Tmax{_tmax_used}_Tmin{args.chrono_tmin}'
else:
    chrono_tag = 'chrono_off'

sched_tag = (
    f"sched_cosine_T0{args.T_0}_Tmult{args.T_mult}"
    if args.scheduler.lower() == "cosine"
    else "sched_none"
)

# The run name encodes the configuration; it is reused for the output,
# checkpoint and TensorBoard directories.
run_name = "_".join([
    model_tag,
    f"opt_{args.optimizer}",
    sched_tag,
    chrono_tag,
    f"lr_{lr}",
    f"batch_size_{batch_size}",
    f"hidden_size_{hidden_size}",
    f"num_layers_{num_layers}",
    f"dropout_{dropout}",
    f"epochs_{epochs}",
    f"data{dataset_tag}",
])

# The run folder is named after the run and created relative to the
# current working directory.
output_dir = run_name
os.makedirs(output_dir, exist_ok=True)

# Record every CLI argument plus a few derived values next to the results.
hparams = {
    **vars(args),
    "dataset_tag": dataset_tag,
    "run_name": run_name,
    "device": str(device),
}

# Write the hyperparameters as JSON inside the run folder.
hparams_path = os.path.join(output_dir, "hparams.json")
with open(hparams_path, 'w') as hparams_file:
    json.dump(hparams, hparams_file, indent=2)

print(f"[HParams] Saved hyperparameters to {hparams_path}")


# Metric snapshots go to --snapdir when given, else <output_dir>/snapshots.
if args.snapdir:
    snap_dir = args.snapdir
else:
    snap_dir = os.path.join(output_dir, "snapshots")
os.makedirs(snap_dir, exist_ok=True)

# Set form of the milestone epochs, used for membership tests in the loop.
slice_set = set(args.slice_epochs) if args.slice_epochs else set()

# Full (resumable) checkpoints live under <ckptdir>/<run_name>.
ckpt_dir = os.path.join(args.ckptdir, run_name)
Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

# TensorBoard setup
#
# `SummaryWriter` is None when the tensorboard import at the top of the file
# failed. In that case `writer` stays None and all logging is skipped; the
# training loop already checks `if writer:` before every use.

tb_dir = os.path.join(args.logdir, run_name)
Path(tb_dir).mkdir(parents=True, exist_ok=True)

writer = SummaryWriter(tb_dir) if SummaryWriter is not None else None
if writer is None:
    print("[TensorBoard] SummaryWriter unavailable; TensorBoard logging disabled.")

# Safe hparams dump now that variables exist
if writer is not None:
    writer.add_text("hparams", json.dumps({
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "dropout": dropout,
        "lr_init": lr,
        "T_0": args.T_0,
        "T_mult": args.T_mult,
        "timesteps": timesteps,
        "batch_size": batch_size
    }, indent=2))

# Custom Scalars layout (linear + log + LR).
# Each chart lists the scalar tags it should overlay.
layout = {
    "Linear": {
        "MSE (train vs val)":   ["Multiline", ["mse/train",   "mse/val"]],
        "MAE (train vs val)":   ["Multiline", ["mae/train",   "mae/val"]],
        "NRMSE (train vs val)": ["Multiline", ["nrmse/train", "nrmse/val"]],
    },
    "Log10": {
        "log10 MSE":   ["Multiline", ["mse_log10/train",   "mse_log10/val"]],
        "log10 MAE":   ["Multiline", ["mae_log10/train",   "mae_log10/val"]],
        "log10 NRMSE": ["Multiline", ["nrmse_log10/train", "nrmse_log10/val"]],
    },
    "Optimization": {
        "Learning Rate": ["Multiline", ["lr/current"]],
    },
}
if writer is not None:
    writer.add_custom_scalars(layout)



# Load the dataset, window it, and split into train / validation

data = load_hdf5_data(file_path)
X, y = create_gru_sequences(data, timesteps)
# Shape of the stacked input windows; the second axis is `timesteps` long.
print(X.shape)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.1, random_state=42)

# Convert all four splits to float32 tensors.
X_train_t, y_train_t, X_val_t, y_val_t = (
    torch.tensor(split, dtype=torch.float32)
    for split in (X_train, y_train, X_val, y_val)
)

# NRMSE denominator: population std (ddof=0) of all training targets.
y_train_flat = y_train_t.reshape(-1)
nrmse_denom = y_train_flat.std(unbiased=False).item()
# Fall back to 1.0 for (near-)constant targets to avoid dividing by ~0.
nrmse_denom = 1.0 if nrmse_denom < 1e-8 else nrmse_denom


# Datasets and loaders; only the training loader shuffles.
train_ds = TensorDataset(X_train_t, y_train_t)
val_ds = TensorDataset(X_val_t, y_val_t)
_loader_opts = dict(batch_size=batch_size, num_workers=1)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)


# Model construction
#
# Three variants are possible:
#   basic                  -> BasicGRUModel (+ optional chrono init on .gru)
#   layernorm + chrono     -> GRUModelLayerNormChrono with chrono init
#   layernorm, no chrono   -> GRUModel

# Tmax for chrono init defaults to the sequence length.
tmax_used = args.timesteps if args.chrono_tmax is None else args.chrono_tmax

# Constructor arguments shared by all three model classes.
_model_kwargs = dict(
    input_size=X.shape[2],
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
)
# Arguments shared by both chrono-init helpers.
_chrono_kwargs = dict(
    Tmax=tmax_used,
    Tmin=args.chrono_tmin,
    debug=args.chrono_debug,
)

if args.model_type == 'basic':
    model = BasicGRUModel(**_model_kwargs).to(device)
    if args.chrono_init:
        # The helper's return value replaces the model's GRU submodule.
        model.gru = apply_chrono_init(model.gru, **_chrono_kwargs)
elif args.chrono_init:
    # layernorm with chrono initialisation
    model = GRUModelLayerNormChrono(**_model_kwargs).to(device)
    model = apply_chrono_init_layernorm_chrono_model(model, **_chrono_kwargs)
else:
    # original layernorm model (no chrono)
    model = GRUModel(**_model_kwargs).to(device)


# Loss functions: MSE is the training objective, MAE is tracked as a metric.
mse_crit = nn.MSELoss()
mae_crit = L1Loss()
opt_name = args.optimizer.lower()

if opt_name in ('adaplus', 'adamw'):
    # Both take the same hyperparameters; only the class differs.
    _opt_cls = AdaPlus if opt_name == 'adaplus' else torch.optim.AdamW
    optimizer = _opt_cls(
        model.parameters(),
        lr=lr,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.weight_decay,
        amsgrad=args.amsgrad
    )
else:
    # Plain Adam is built with only the learning rate; the beta / eps /
    # weight-decay / amsgrad flags are not forwarded to it.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
    )

# Optional LR schedule; `scheduler` stays None when disabled.
scheduler = None
if args.scheduler.lower() == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=args.T_0,
        T_mult=args.T_mult
    )
    
# Resume logic and training-state initialisation
#
# The best-so-far validation values start at +inf for a fresh run and are
# replaced by the checkpoint's values when resuming. They must NOT be reset
# afterwards: doing so would make the first resumed epoch look like a new
# best and overwrite the best_by_*.pth files with a possibly worse model.
start_epoch = 0  # will shift the training loop start
best_val_mse = float('inf')
best_val_nrmse = float('inf')

if args.resume:
    loaded_start, best_val_mse, best_val_nrmse = load_checkpoint(
        args.resume, model, optimizer, scheduler, map_location=device
    )
    # A model-only weights file (e.g., last.pth) carries no epoch, so
    # load_checkpoint reports 0. In that case honor the --start_epoch flag.
    start_epoch = loaded_start if loaded_start > 0 else max(0, args.start_epoch)
    print(f"[Resume] Loaded '{args.resume}'. Starting from epoch {start_epoch+1}.")


# Hyperparameters embedded in every full checkpoint.
ckpt_hparams = {
    "hidden_size": hidden_size,
    "num_layers": num_layers,
    "dropout": dropout,
    "lr_init": lr,
    "T_0": args.T_0,
    "T_mult": args.T_mult,
    "timesteps": timesteps,
    "batch_size": batch_size,
}


# Per-epoch metric histories for this process (they start empty even when
# resuming, since earlier epochs' values are not stored in the checkpoint).
train_mses, val_mses = [], []
train_maes, val_maes = [], []
train_nrmse, val_nrmse = [], []
start_time = time.time()

# Epoch at which each best value was reached during THIS run;
# -1 means no improvement has been recorded yet.
best_epoch_mse = -1
best_epoch_nrmse = -1

print(f"Training on device: {device}")

# Floor applied before log10 so that an exactly-zero metric cannot raise
# ValueError (math.log10 is undefined at 0).
_LOG_FLOOR = 1e-12

# Main training loop. Each epoch: one optimisation pass over train_loader,
# one evaluation pass over val_loader, TensorBoard logging, an optional LR
# scheduler step, metric bookkeeping and checkpointing.
for epoch in range(start_epoch + 1, epochs + 1):  # resume aware
    epoch_start = time.time()

    # ---------------- training pass ----------------
    model.train()
    train_mse = 0.0   # sum of per-batch MSE, averaged over batches below
    train_mae = 0.0   # sum of per-batch MAE, averaged over batches below
    sse_train = 0.0   # sum of squared errors over all scalar targets
    n_train = 0       # number of scalar targets seen

    for Xb, yb in train_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        out = model(Xb)
        mse_loss = mse_crit(out, yb)  # batch mean squared error (objective)

        optimizer.zero_grad()
        mse_loss.backward()
        optimizer.step()

        train_mse += mse_loss.item()

        # Metrics only: same values as before, computed without building an
        # autograd graph since they are never back-propagated.
        with torch.no_grad():
            train_mae += mae_crit(out, yb).item()
            err = out - yb
            sse_train += (err * err).sum().item()
            n_train += err.numel()

    train_mse /= len(train_loader)
    train_mae /= len(train_loader)

    # Epoch RMSE from accumulated SSE/N, then NRMSE (std-normalized).
    rmse_train = math.sqrt(max(sse_train / max(n_train, 1), 0.0) + 1e-12)
    nrmse_train_epoch = rmse_train / max(nrmse_denom, 1e-12)

    # ---------------- validation pass ----------------
    model.eval()
    val_mse = 0.0
    val_mae = 0.0
    sse_val = 0.0
    n_val = 0
    with torch.no_grad():
        for Xb, yb in val_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            out = model(Xb)

            val_mse += mse_crit(out, yb).item()
            val_mae += mae_crit(out, yb).item()

            err = out - yb
            sse_val += (err * err).sum().item()
            n_val += err.numel()

    val_mse /= len(val_loader)
    val_mae /= len(val_loader)

    # Validation NRMSE uses the same train-set denominator (avoids leakage).
    rmse_val = math.sqrt(max(sse_val / max(n_val, 1), 0.0) + 1e-12)
    nrmse_val_epoch = rmse_val / max(nrmse_denom, 1e-12)

    # ---------------- TensorBoard logging ----------------
    if writer:
        # One scalar per "<metric>/<split>" tag. These are exactly the tags
        # the custom-scalars layout refers to; add_scalars("mse", {...})
        # would instead write tag "mse" into separate sub-runs, which the
        # layout's tag patterns never match.
        for metric, train_value, val_value in (
            ("mse", train_mse, val_mse),
            ("mae", train_mae, val_mae),
            ("nrmse", nrmse_train_epoch, nrmse_val_epoch),
        ):
            writer.add_scalar(f"{metric}/train", train_value, epoch)
            writer.add_scalar(f"{metric}/val", val_value, epoch)
            # log10 view, floored so a zero metric does not raise.
            writer.add_scalar(f"{metric}_log10/train",
                              math.log10(max(train_value, _LOG_FLOOR)), epoch)
            writer.add_scalar(f"{metric}_log10/val",
                              math.log10(max(val_value, _LOG_FLOOR)), epoch)

    if scheduler is not None:
        scheduler.step()
        # Log LR after stepping, i.e. the LR the next epoch will use.
        if writer:
            curr_lr = _get_lr(optimizer)
            if curr_lr is not None:
                writer.add_scalar("lr/current", curr_lr, epoch)

    # ---------------- metric histories ----------------
    train_mses.append(train_mse)
    val_mses.append(val_mse)
    train_maes.append(train_mae)
    val_maes.append(val_mae)
    train_nrmse.append(nrmse_train_epoch)
    val_nrmse.append(nrmse_val_epoch)

    # Save rolling snapshots of the validation curves at milestone epochs.
    # NOTE(review): the histories only contain epochs run by this process,
    # so after a resume the "1to{epoch}" files start at the resume point.
    if epoch in slice_set:
        np.save(os.path.join(snap_dir, f"val_mse_1to{epoch}.npy"),   np.array(val_mses[:epoch]))
        np.save(os.path.join(snap_dir, f"val_mae_1to{epoch}.npy"),   np.array(val_maes[:epoch]))
        np.save(os.path.join(snap_dir, f"val_nrmse_1to{epoch}.npy"), np.array(val_nrmse[:epoch]))
        print(f"[Snapshot] Saved validation curves 1..{epoch} to {snap_dir}")

    epoch_duration = time.time() - epoch_start

    print(f"Epoch {epoch}/{epochs} | "
          f"Train MSE: {train_mse:.6f}, MAE: {train_mae:.6f}, NRMSE(mean): {nrmse_train_epoch:.6f} | "
          f"Val MSE: {val_mse:.6f}, MAE: {val_mae:.6f}, NRMSE(mean): {nrmse_val_epoch:.6f} | "
          f"Time: {epoch_duration:.2f}s")

    # ---------------- checkpointing ----------------
    # Best-by-MSE uses the exact dataset MSE (SSE/N), not the mean of
    # per-batch means.
    mse_val_exact = sse_val / max(n_val, 1)
    if mse_val_exact < best_val_mse:
        best_val_mse = mse_val_exact
        best_epoch_mse = epoch
        torch.save(model.state_dict(), os.path.join(output_dir, "best_by_mse.pth"))
        save_checkpoint(
            os.path.join(ckpt_dir, "best_by_mse.pth"),
            epoch, model, optimizer, scheduler,
            best_val_mse, best_val_nrmse,
            hparams=ckpt_hparams
        )

    # Best-by-NRMSE (std-normalized).
    if nrmse_val_epoch < best_val_nrmse:
        best_val_nrmse = nrmse_val_epoch
        best_epoch_nrmse = epoch
        torch.save(model.state_dict(), os.path.join(output_dir, "best_by_nrmse.pth"))
        save_checkpoint(
            os.path.join(ckpt_dir, "best_by_nrmse.pth"),
            epoch, model, optimizer, scheduler,
            best_val_mse, best_val_nrmse,
            hparams=ckpt_hparams
        )

    # Weights of the most recent epoch (model-only file).
    torch.save(model.state_dict(), os.path.join(output_dir, "last.pth"))

    # Rolling full checkpoint, rewritten every epoch. It is stamped with the
    # epoch that actually finished (not the planned total `epochs`), so that
    # resuming from it after an interruption continues at epoch + 1 instead
    # of skipping training altogether. After the last epoch epoch == epochs,
    # so the final file is unchanged.
    # NOTE(review): --save_every is parsed but not used anywhere in this
    # file; periodic numbered checkpoints are not written.
    save_checkpoint(
        os.path.join(ckpt_dir, "final.pth"),
        epoch, model, optimizer, scheduler,
        best_val_mse, best_val_nrmse,
        hparams=ckpt_hparams
    )


# End-of-run summary: where the best and last weights were written.
_best_mse_file = os.path.join(output_dir, 'best_by_mse.pth')
_best_nrmse_file = os.path.join(output_dir, 'best_by_nrmse.pth')
_last_file = os.path.join(output_dir, 'last.pth')

for _line in (
    "\n================ Training complete ================",
    f"Best-by-MSE   : {best_val_mse:.6f} at epoch {best_epoch_mse} -> saved to {_best_mse_file}",
    f"Best-by-NRMSE : {best_val_nrmse:.6f} at epoch {best_epoch_nrmse} -> saved to {_best_nrmse_file}",
    f"Final (last)  : epoch {epochs} -> saved to {_last_file}",
    "===================================================\n",
):
    print(_line)

# Flush and release the TensorBoard writer, if there is one.
if writer:
    writer.close()


# Persist the validation curves collected during this run.
for _file_name, _series in (
    ("val_mse_before.npy", val_mses),
    ("val_mae_before.npy", val_maes),
    ("val_nrmse_before.npy", val_nrmse),
):
    np.save(os.path.join(output_dir, _file_name), np.array(_series))

        
# Training-curve plots: one linear-scale and one log-scale PNG per metric,
# all written into the run folder.

# Linear-scale plots (MSE, MAE, NRMSE): styled lines with markers.
# Fields: train series, val series, train label, val label,
#         train color, val color, title, y-axis label, file name.
_linear_plots = (
    (train_mses, val_mses, 'Train MSE', 'Val MSE', 'blue', 'red',
     'Mean Squared Error', 'MSE', "mse_plot.png"),
    (train_maes, val_maes, 'Train MAE', 'Val MAE', 'green', 'orange',
     'Mean Absolute Error', 'MAE', "mae_plot.png"),
    (train_nrmse, val_nrmse, 'Train nrmse', 'Val nrmse', 'green', 'orange',
     'Normalized Root Mean Square Error', 'NRMSE', "nrmse_plot.png"),
)
for (_train, _val, _train_label, _val_label,
     _train_color, _val_color, _title, _ylabel, _png) in _linear_plots:
    plt.figure(figsize=(8, 4))
    plt.plot(_train, label=_train_label, linestyle='-', color=_train_color,
             linewidth=1, marker='o', markersize=4)
    plt.plot(_val, label=_val_label, linestyle='--', color=_val_color,
             linewidth=1, marker='x', markersize=4)
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.savefig(os.path.join(output_dir, _png), dpi=300)

# Log-scale plots (MSE, MAE, NRMSE): default figure size and line style.
# Fields: train series, val series, train label, val label,
#         title, y-axis label, file name.
_log_plots = (
    (train_mses, val_mses, 'Train MSE', 'Val MSE',
     "MSE (Log Scale)", "MSE", "mse_log_plot.png"),
    (train_maes, val_maes, 'Train MAE', 'Val MAE',
     "MAE (Log Scale)", "MAE", "mae_log_plot.png"),
    (train_nrmse, val_nrmse, 'Train nrmse', 'Val nrmse',
     "NRMSE (Log Scale)", "NRMSE", "nrmse_log_plot.png"),
)
for (_train, _val, _train_label, _val_label,
     _title, _ylabel, _png) in _log_plots:
    plt.figure()
    plt.plot(_train, label=_train_label)
    plt.plot(_val, label=_val_label)
    plt.yscale("log")
    plt.title(_title)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, _png), dpi=300)
